{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8cc82b48",
   "metadata": {},
   "source": [
    "# How to use the MultiQueryRetriever\n",
    "\n",
    "Distance-based vector database retrieval embeds (represents) queries in a high-dimensional space and finds similar embedded documents based on a distance metric. However, retrieval may produce different results with subtle changes in query wording, or if the embeddings do not capture the semantics of the data well. Prompt engineering or tuning is sometimes done to address these problems manually, but it can be tedious.\n",
    "\n",
    "The [MultiQueryRetriever](https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.multi_query.MultiQueryRetriever.html) automates the process of prompt tuning by using an LLM to generate multiple queries from different perspectives for a given user input query. For each query, it retrieves a set of relevant documents and takes the unique union across all queries to get a larger set of potentially relevant documents. By generating multiple perspectives on the same question, the `MultiQueryRetriever` can mitigate some of the limitations of distance-based retrieval and return a richer set of results.\n",
    "\n",
    "Let's build a vector store from the [LLM Powered Autonomous Agents](https://lilianweng.github.io/posts/2023-06-23-agent/) blog post by Lilian Weng, which is also used in the [RAG tutorial](/docs/tutorials/rag):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "994d6c74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a sample vectorDB\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Load blog post\n",
    "loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
    "data = loader.load()\n",
    "\n",
    "# Split\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
    "splits = text_splitter.split_documents(data)\n",
    "\n",
    "# VectorDB\n",
    "embedding = OpenAIEmbeddings()\n",
    "vectordb = Chroma.from_documents(documents=splits, embedding=embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cca8f56c",
   "metadata": {},
   "source": [
    "#### Simple usage\n",
    "\n",
    "Specify the LLM to use for query generation, and the retriever will do the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "edbca101",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers.multi_query import MultiQueryRetriever\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "question = \"What are the approaches to Task Decomposition?\"\n",
    "llm = ChatOpenAI(temperature=0)\n",
    "retriever_from_llm = MultiQueryRetriever.from_llm(\n",
    "    retriever=vectordb.as_retriever(), llm=llm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9e6d3b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set logging for the queries\n",
    "import logging\n",
    "\n",
    "logging.basicConfig()\n",
    "logging.getLogger(\"langchain.retrievers.multi_query\").setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bc93dc2b-9407-48b0-9f9a-338247e7eb69",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be achieved through different methods?', '2. What strategies are commonly used for Task Decomposition?', '3. What are the various techniques for breaking down tasks in Task Decomposition?']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unique_docs = retriever_from_llm.invoke(question)\n",
    "len(unique_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e170263-facd-4065-bb68-d11fb9123a45",
   "metadata": {},
   "source": [
    "Note that the underlying queries generated by the retriever are logged at the `INFO` level."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c54a282f",
   "metadata": {},
   "source": [
    "#### Supplying your own prompt\n",
    "\n",
    "Under the hood, `MultiQueryRetriever` generates queries using a specific [prompt](https://api.python.langchain.com/en/latest/_modules/langchain/retrievers/multi_query.html#MultiQueryRetriever). To customize this prompt:\n",
    "\n",
    "1. Make a [PromptTemplate](https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.prompt.PromptTemplate.html) with an input variable for the question;\n",
    "2. Implement an [output parser](/docs/concepts#output-parsers) like the one below to split the result into a list of queries.\n",
    "\n",
    "The prompt and output parser together must support the generation of a list of queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d9afb0ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "from langchain_core.output_parsers import BaseOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "\n",
    "# Output parser will split the LLM result into a list of queries\n",
    "class LineListOutputParser(BaseOutputParser[List[str]]):\n",
    "    \"\"\"Output parser for a list of lines.\"\"\"\n",
    "\n",
    "    def parse(self, text: str) -> List[str]:\n",
    "        # Split on newlines and drop blank lines so that empty queries\n",
    "        # are never sent to the underlying retriever\n",
    "        lines = text.strip().split(\"\\n\")\n",
    "        return [line.strip() for line in lines if line.strip()]\n",
    "\n",
    "\n",
    "output_parser = LineListOutputParser()\n",
    "\n",
    "QUERY_PROMPT = PromptTemplate(\n",
    "    input_variables=[\"question\"],\n",
    "    template=\"\"\"You are an AI language model assistant. Your task is to generate five \n",
    "    different versions of the given user question to retrieve relevant documents from a vector \n",
    "    database. By generating multiple perspectives on the user question, your goal is to help\n",
    "    the user overcome some of the limitations of the distance-based similarity search. \n",
    "    Provide these alternative questions separated by newlines.\n",
    "    Original question: {question}\"\"\",\n",
    ")\n",
    "llm = ChatOpenAI(temperature=0)\n",
    "\n",
    "# Chain\n",
    "llm_chain = QUERY_PROMPT | llm | output_parser\n",
    "\n",
    "# Other inputs\n",
    "question = \"What are the approaches to Task Decomposition?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "59c75c56-dbd7-4887-b9ba-0b5b21069f51",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:langchain.retrievers.multi_query:Generated queries: ['1. Can you provide insights on regression from the course material?', '2. How is regression discussed in the course content?', '3. What information does the course offer about regression?', '4. In what way is regression covered in the course?', '5. What are the teachings of the course regarding regression?']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run\n",
    "retriever = MultiQueryRetriever(\n",
    "    retriever=vectordb.as_retriever(), llm_chain=llm_chain\n",
    ")  # llm_chain must return a list of query strings\n",
    "\n",
    "# Results\n",
    "unique_docs = retriever.invoke(\"What does the course say about regression?\")\n",
    "len(unique_docs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
